# coding=utf-8
# Copyright 2022 The ML Fairness Gym Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Plotting functions for the attention allocation example experiments."""

import copy
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# TODO() write tests for the plotting functions.

# Font sizes shared by the plotting functions in this module: MEDIUM_FONTSIZE
# is used for tick labels and legends, LARGE_FONTSIZE for axis labels.
MEDIUM_FONTSIZE = 14
LARGE_FONTSIZE = 16


def create_dataframe_from_results(agent_names,
                                  value_to_report_dicts,
                                  separate_locations=False):
  """Turn json reports into a dataframe that can be used to plot the results.

  Args:
    agent_names: list of str names of each agent.
    value_to_report_dicts: list of dictionaries parallel to agent_names; the
      i-th dictionary holds the experiments run for agent_names[i]. Each
      dictionary maps a value of the dynamic factor parameter to the json
      string report produced for that value.
    separate_locations: if True, emit one row per (agent, parameter value,
      location) holding per-location metrics. If False, emit one row per
      (agent, parameter value) holding metrics aggregated over locations.

  Returns:
    A pandas dataframe. Every row has 'agent_type', 'param_explored' and
    'param_value' columns. With separate_locations=True the remaining columns
    are 'location', 'discovered_incidents', 'missed_incidents',
    'discovered/occurred' and 'discovered/occurred weighted'; otherwise they
    are 'total_discovered', 'total_missed', 'discovered/occurred range' and
    'discovered/occurred range weighted'.

  Raises:
    ValueError: if agent_names and value_to_report_dicts differ in length.
  """
  if len(agent_names) != len(value_to_report_dicts):
    raise ValueError(
        'agent_names and value_to_report_dicts must have the same length, '
        'got %d and %d.' % (len(agent_names), len(value_to_report_dicts)))

  pandas_df_data = []
  # TODO() consider pandas.read_json to simplify the report parsing below.
  for agent_name, value_to_report in zip(agent_names, value_to_report_dicts):
    # The keys of value_to_report are the values of the varied parameter, in
    # this case, the dynamic rate.
    for value, report_json in value_to_report.items():
      base_row = {
          'agent_type': agent_name,
          'param_explored': 'dynamic factor',
          'param_value': str(value),
      }
      report = json.loads(report_json)
      metrics = report['metrics']
      n_locations = report['env_params']['n_locations']

      discovered_incidents = np.array(metrics['discovered_incidents'])
      missed_incidents = (
          np.array(metrics['occurred_incidents']) - discovered_incidents)
      discovered_over_occurred = np.array(metrics['discovered_occurred_ratio'])
      discovered_over_occurred_weighted = np.array(
          metrics['discovered_occurred_ratio_weighted'])

      if separate_locations:
        for location in range(n_locations):
          pandas_df_data.append({
              **base_row,
              'location': location,
              'discovered_incidents': discovered_incidents[location],
              'missed_incidents': missed_incidents[location],
              'discovered/occurred': discovered_over_occurred[location],
              'discovered/occurred weighted':
                  discovered_over_occurred_weighted[location],
          })
      else:
        pandas_df_data.append({
            **base_row,
            'total_discovered': np.sum(discovered_incidents),
            'total_missed': np.sum(missed_incidents),
            # Spread (max - min) of the ratio across locations.
            'discovered/occurred range': np.ptp(discovered_over_occurred),
            'discovered/occurred range weighted':
                np.ptp(discovered_over_occurred_weighted),
        })
  return pd.DataFrame(pandas_df_data)


def plot_occurence_action_single_dynamic(report, file_path=''):
  """Plot line charts of actions and incident occurrences over time.

  Draws two side-by-side panels with one line per location. The left panel
  shows attention units allocated at each step and the right panel shows
  incidents occurred at each step. The figure is saved to
  `file_path + '.pdf'`.

  Args:
    report: dict report read via report['env_params']['n_locations'] and
      report['metrics']['history'].
    file_path: path prefix of the saved pdf (the '.pdf' suffix is appended).
  """
  # History here is not a core.HistoryItem object. It is a list of lists
  # generated by attention_allocation_experiment's _get_relevant_history().
  # item[0] is plotted as incidents and item[1] as the action, each indexed
  # per location.
  history = report['metrics']['history']
  n_locations = report['env_params']['n_locations']
  incidents_data = np.asarray([item[0] for item in history])
  action_data = np.asarray([item[1] for item in history])

  plt.figure(figsize=(16, 4))
  panels = ((action_data, 'Attention units allocated'),
            (incidents_data, 'Incidents occurred'))
  for panel_index, (data, ylabel) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_index)
    for location in range(n_locations):
      plt.plot(data[:, location], label='loc=%d' % location)
    plt.xlabel('Time steps', fontsize=LARGE_FONTSIZE)
    plt.ylabel(ylabel, fontsize=LARGE_FONTSIZE)
    plt.xticks(fontsize=MEDIUM_FONTSIZE)
    plt.yticks(fontsize=MEDIUM_FONTSIZE)
    # The per-location line labels are only visible once a legend is drawn.
    plt.legend(fontsize=MEDIUM_FONTSIZE)
  plt.savefig(file_path + '.pdf', bbox_inches='tight')


def plot_discovered_missed_clusters(dataframe, file_path=''):
  """Plot location clusters comparing agents missed and discovered incidents.

  Produces two categorical scatter plots of per-location incident counts
  against the dynamic factor, colored by agent. They are saved to
  `file_path + '_missed.pdf'` and `file_path + '_discovered.pdf'`.

  Args:
    dataframe: dataframe with 'param_value', 'agent_type', 'missed_incidents'
      and 'discovered_incidents' columns, e.g. as produced by
      create_dataframe_from_results(..., separate_locations=True).
    file_path: path prefix for the saved pdf files.
  """
  plot_height = 5
  aspect_ratio = 1.3
  plots = (
      ('missed_incidents', 'Missed incidents for each location',
       '_missed.pdf'),
      ('discovered_incidents', 'Discovered incidents for each location',
       '_discovered.pdf'),
  )
  for column, ylabel, suffix in plots:
    sns.catplot(
        x='param_value',
        y=column,
        data=dataframe,
        hue='agent_type',
        # Same palette for both plots so each agent keeps its color.
        palette='deep',
        height=plot_height,
        aspect=aspect_ratio,
        s=8,
        legend=False)
    plt.xlabel('Dynamic factor', fontsize=LARGE_FONTSIZE)
    plt.ylabel(ylabel, fontsize=LARGE_FONTSIZE)
    plt.xticks(fontsize=MEDIUM_FONTSIZE)
    plt.yticks(fontsize=MEDIUM_FONTSIZE)
    plt.legend(fontsize=MEDIUM_FONTSIZE, title_fontsize=MEDIUM_FONTSIZE)
    plt.tight_layout()
    plt.savefig(file_path + suffix)


def plot_total_miss_discovered(dataframe, file_path=''):
  """Plot bar charts comparing agents total missed and discovered incidents.

  Produces two bar charts of the totals against the dynamic factor, grouped
  by agent. They are saved to `file_path + '_missed.pdf'` and
  `file_path + '_discovered.pdf'`. Sets the global seaborn style to
  'whitegrid' as a side effect.

  Args:
    dataframe: dataframe with 'param_value', 'agent_type', 'total_missed' and
      'total_discovered' columns, e.g. as produced by
      create_dataframe_from_results(..., separate_locations=False).
    file_path: path prefix for the saved pdf files.
  """
  plot_height = 5
  aspect_ratio = 1.3
  sns.set_style('whitegrid')
  # No sns.despine() here: called before catplot it would act on whatever
  # figure is current at that moment, not on the figures catplot creates.

  plots = (
      ('total_missed', 'Total missed incidents', '_missed.pdf'),
      ('total_discovered', 'Total discovered incidents', '_discovered.pdf'),
  )
  for column, ylabel, suffix in plots:
    sns.catplot(
        x='param_value',
        y=column,
        data=dataframe,
        hue='agent_type',
        kind='bar',
        palette='muted',
        height=plot_height,
        aspect=aspect_ratio,
        legend=False)
    plt.xlabel('Dynamic factor', fontsize=LARGE_FONTSIZE)
    plt.ylabel(ylabel, fontsize=LARGE_FONTSIZE)
    plt.xticks(fontsize=MEDIUM_FONTSIZE)
    plt.yticks(fontsize=MEDIUM_FONTSIZE)
    plt.legend(fontsize=MEDIUM_FONTSIZE, title_fontsize=MEDIUM_FONTSIZE)
    plt.savefig(file_path + suffix, bbox_inches='tight')


def plot_discovered_occurred_ratio_range(dataframe, file_path=''):
  """Plot the range of discovered incidents/occurred range between locations.

  Produces two bar charts against the dynamic factor, grouped by agent: the
  unweighted range, saved to `file_path + '.pdf'`, and the weighted range,
  saved to `file_path + '_weighted.pdf'`. The legend is left to catplot's
  default. Sets the global seaborn style to 'whitegrid' as a side effect.

  Args:
    dataframe: dataframe with 'param_value', 'agent_type',
      'discovered/occurred range' and 'discovered/occurred range weighted'
      columns, e.g. as produced by
      create_dataframe_from_results(..., separate_locations=False).
    file_path: path prefix for the saved pdf files.
  """
  plot_height = 5
  aspect_ratio = 1.3
  sns.set_style('whitegrid')
  # No sns.despine() here: called before catplot it would act on whatever
  # figure is current at that moment, not on the figures catplot creates.

  plots = (
      ('discovered/occurred range', 'Discovered/occurred range', '.pdf'),
      ('discovered/occurred range weighted', 'Delta', '_weighted.pdf'),
  )
  for column, ylabel, suffix in plots:
    sns.catplot(
        x='param_value',
        y=column,
        data=dataframe,
        hue='agent_type',
        kind='bar',
        palette='muted',
        height=plot_height,
        aspect=aspect_ratio)
    plt.xlabel('Dynamic factor', fontsize=LARGE_FONTSIZE)
    plt.ylabel(ylabel, fontsize=LARGE_FONTSIZE)
    plt.xticks(fontsize=MEDIUM_FONTSIZE)
    plt.yticks(fontsize=MEDIUM_FONTSIZE)
    plt.savefig(file_path + suffix, bbox_inches='tight')


def plot_discovered_occurred_ratio_locations(dataframe, file_path=''):
  """Plot the discovered incidents/occurred ratio for each location.

  Produces two categorical scatter plots against the dynamic factor, colored
  by agent: the unweighted ratio, saved to `file_path + '.pdf'`, and the
  weighted ratio, saved to `file_path + '_weighted.pdf'`. The legend is left
  to catplot's default. Resets the global seaborn theme with style='ticks'
  as a side effect.

  Args:
    dataframe: dataframe with 'param_value', 'agent_type',
      'discovered/occurred' and 'discovered/occurred weighted' columns, e.g.
      as produced by create_dataframe_from_results(...,
      separate_locations=True).
    file_path: path prefix for the saved pdf files.
  """
  plot_height = 5
  aspect_ratio = 1.3
  # A single global theme reset is enough; nothing between the two plots
  # changes it. sns.despine() is not called because, before catplot, it would
  # act on whatever figure is current rather than on the plotted figures.
  sns.set(style='ticks')

  plots = (
      ('discovered/occurred', 'Discovered/occurred', '.pdf'),
      ('discovered/occurred weighted', 'Discovered/occurred weighted',
       '_weighted.pdf'),
  )
  for column, ylabel, suffix in plots:
    sns.catplot(
        x='param_value',
        y=column,
        data=dataframe,
        hue='agent_type',
        height=plot_height,
        aspect=aspect_ratio,
        s=8)
    plt.xlabel('Dynamic factor', fontsize=LARGE_FONTSIZE)
    plt.ylabel(ylabel, fontsize=LARGE_FONTSIZE)
    plt.xticks(fontsize=MEDIUM_FONTSIZE)
    plt.yticks(fontsize=MEDIUM_FONTSIZE)
    plt.savefig(file_path + suffix, bbox_inches='tight')
